---
title: "SRCNN+BN: Image super-resolution using deep convolutional networks + Batch Normalization"
keywords: fastai
sidebar: home_sidebar
---
# Notebook setup: enable IPython's autoreload extension (mode 2) and inline
# matplotlib figures.
%reload_ext autoreload
%autoreload 2
%matplotlib inline
import sys
# Add the parent directory (relative to the working directory) to the import
# path — presumably where the `superres` package lives.
sys.path.append('..')
# NOTE(review): the rest of the script uses `random`, `np`, `partial`, `Learner`,
# etc. without importing them, so they presumably arrive via these star-imports.
# Order matters: names from the second import shadow same-named ones from the first.
from superres.datasets import *
from superres.databunch import *
# Imported here so that seeding does not depend on the star-imports exposing `torch`.
import torch


def seed_everything(value):
    """Seed the Python, NumPy and PyTorch random number generators with *value*.

    Seeding only `random` and `np.random` leaves PyTorch's generator unseeded,
    so anything drawn from it (presumably the model's weight initialisation)
    would still differ from run to run.
    """
    random.seed(value)
    np.random.seed(value)
    torch.manual_seed(value)


# Single seed shared by the RNGs above and by the databunch creation below.
seed = 8610
seed_everything(seed)
# Training-data source and databunch settings.
# NOTE(review): the exact meaning of in_size/out_size/scale is defined by
# `create_sr_databunch`, which is not visible here — confirm against its docs.
train_hr = div2k_train_hr_crop_256
in_size = out_size = 256
scale, bs = 4, 10

# Build the databunch; `seed` is the value fixed earlier in the script.
data = create_sr_databunch(
    train_hr,
    in_size=in_size,
    out_size=out_size,
    scale=scale,
    bs=bs,
    seed=seed,
)

# Sanity check: print the databunch and display one batch.
print(data)
data.show_batch()
# Instantiate the network; its class name is reused later as a label and as the
# checkpoint file name.
model = SRCNN_BN()
model_name = type(model).__name__

# Train with a flattened MSE loss and report PSNR / SSIM metrics.
loss_func = MSELossFlat()
metrics = [m_psnr, m_ssim]
learn = Learner(
    data,
    model,
    loss_func=loss_func,
    metrics=metrics,
)

# Run the learning-rate finder and plot the recorded curve with a suggested value.
lr_find(learn)
learn.recorder.plot(suggestion=True)
# Hyper-parameters for the one-cycle training run.
epoch = 3
pct_start = 0.3
lr = wd = 1e-3
lrs = slice(lr)

# The checkpoint is written under the model's class name; the same name is used
# further down to load the weights back.
save_fname = model_name
callbacks = [
    ShowGraph(learn),
    SaveModelCallback(learn, name=save_fname),
]

learn.fit_one_cycle(
    epoch,
    lrs,
    pct_start=pct_start,
    wd=wd,
    callbacks=callbacks,
)

# Display sample results from the trained learner.
learn.show_results()
# Evaluate the saved model on the images under `set14_hr`.
test_hr = set14_hr

# Test inputs: opened through `after_open_image` with `scale` and `sizeup=True`.
# The `scale` variable from the databunch setup is reused instead of repeating the
# literal 4, so the test inputs cannot drift from the training scale.
# NOTE(review): presumably this downsamples by `scale` and resizes back up — confirm
# against `after_open_image`.
il_test_x = ImageImageList.from_folder(
    test_hr,
    after_open=partial(after_open_image, size=out_size, scale=scale, sizeup=True),
)

# Test targets: the same folder opened with `size=out_size` only.
il_test_y = ImageImageList.from_folder(
    test_hr,
    after_open=partial(after_open_image, size=out_size),
)

# Restore the checkpoint written during training before running the test.
_ = learn.load(save_fname)
sr_test(learn, il_test_x, il_test_y, model_name)
# Inspect the network. A bare expression is only echoed when it is the last
# statement of a notebook cell; run as a plain script this line has no visible effect.
model
# Summary of the learner's model as provided by `learn.summary()`.
learn.summary()